--- external/gsn/models/model_utils.py	2023-03-09 18:12:06.857800424 +0000
+++ external_reference/gsn/models/model_utils.py	2023-03-09 18:07:58.345065235 +0000
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 import numpy as np
+from torch_utils import persistence
 
 
 def flatten_trajectories(data):
@@ -68,7 +69,7 @@
         if nerf_out_res:
             self.nerf_out_res = nerf_out_res
 
-
+@persistence.persistent_class
 class TrajectorySampler(nn.Module):
     """Trajectory sampler.
 
@@ -99,7 +100,8 @@
     def __init__(self, real_Rts, mode='sample', num_bins=10, alpha_activation='relu', jitter_range=0):
         super().__init__()
 
-        self.real_Rts = nn.Parameter(real_Rts, requires_grad=False)  # shape [n_trajectories, seq_len, 4, 4]
+        # CHANGED: register as a buffer so that it is not updated during GAN updates
+        self.register_buffer('real_Rts', real_Rts)
         self.mode = mode
         self.num_bins = num_bins
         self.alpha_activation = alpha_activation
@@ -107,9 +109,10 @@
 
         # convert Rt matrices to camera pose matrices, then extract translation component
         # make sure Rts are float, since inverse doesn't work with FP16
-        self.real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
+        real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
         # shape [n_trajectories, seq_len, 3]
-        self.real_trajectories = nn.Parameter(self.real_trajectories, requires_grad=False)
+        # CHANGED: register as a buffer so that it is not updated during GAN updates
+        self.register_buffer('real_trajectories', real_trajectories)
         self.seq_len = self.real_trajectories.shape[1]
 
         if mode == 'bin':
